{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RD3uxzaJweYr"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "C-vBUz5IhJs8"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pHTibyMehTvH"
      },
      "source": [
        "# Tutorial: Convert models trained using TensorFlow Object Detection API to TensorFlow Lite\n",
        "\n",
        "This tutorial demonstrate these steps:\n",
        "* Convert TensorFlow models trained using the TensorFlow Object Detection API to [TensorFlow Lite](https://www.tensorflow.org/lite).\n",
        "* Add the required metadata using [TFLite Metadata Writer API](https://www.tensorflow.org/lite/convert/metadata_writer_tutorial#object_detectors). This will make the TFLite model compatible with [TFLite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/object_detector), so that the model can be integrated in mobile apps in 3 lines of code."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QIR1IFpnLJJA"
      },
      "source": [
        "\u003ctable align=\"left\"\u003e\u003ctd\u003e\n",
        "  \u003ca target=\"_blank\"  href=\"https://colab.sandbox.google.com/github/tensorflow/models/blob/master/research/object_detection/colab_tutorials/convert_odt_model_to_TFLite.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\n",
        "  \u003c/a\u003e\n",
        "\u003c/td\u003e\u003ctd\u003e\n",
        "  \u003ca target=\"_blank\"  href=\"https://github.com/tensorflow/models/blob/master/research/object_detection/colab_tutorials/convert_odt_model_to_TFLite.ipynb\"\u003e\n",
        "    \u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "\u003c/td\u003e\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ok_Rpv7XNaFJ"
      },
      "source": [
        "## Preparation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t7CAW5C1cmel"
      },
      "source": [
        "### Install the TFLite Support Library"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DwtFa0jSnNU4"
      },
      "outputs": [],
      "source": [
        "!pip install -q tflite_support"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XRfJR9QXctAR"
      },
      "source": [
        "### Install the TensorFlow Object Detection API\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7PP2P5XAqeI5"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import pathlib\n",
        "\n",
        "# Clone the tensorflow models repository if it doesn't already exist\n",
        "if \"models\" in pathlib.Path.cwd().parts:\n",
        "  while \"models\" in pathlib.Path.cwd().parts:\n",
        "    os.chdir('..')\n",
        "elif not pathlib.Path('models').exists():\n",
        "  !git clone --depth 1 https://github.com/tensorflow/models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bP6SSh6zqi07"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "cd models/research/\n",
        "protoc object_detection/protos/*.proto --python_out=.\n",
        "cp object_detection/packages/tf2/setup.py .\n",
        "pip install -q ."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i0to7aXKc0O9"
      },
      "source": [
        "### Import the necessary libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4M8CC1PgqnSf"
      },
      "outputs": [],
      "source": [
        "import matplotlib\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import os\n",
        "import random\n",
        "import io\n",
        "import imageio\n",
        "import glob\n",
        "import scipy.misc\n",
        "import numpy as np\n",
        "from six import BytesIO\n",
        "from PIL import Image, ImageDraw, ImageFont\n",
        "from IPython.display import display, Javascript\n",
        "from IPython.display import Image as IPyImage\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "from object_detection.utils import label_map_util\n",
        "from object_detection.utils import config_util\n",
        "from object_detection.utils import visualization_utils as viz_utils\n",
        "from object_detection.utils import colab_utils\n",
        "from object_detection.utils import config_util\n",
        "from object_detection.builders import model_builder\n",
        "\n",
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s9WIOOMTNti5"
      },
      "source": [
        "## Download a pretrained model from Model Zoo\n",
        "\n",
        "In this tutorial, we demonstrate converting a pretrained model `SSD MobileNet V2 FPNLite 640x640` in the [TensorFlow 2 Model Zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md). You can replace the model with your own model and the rest will work the same."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TIY3cxDgsxuZ"
      },
      "outputs": [],
      "source": [
        "!wget http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8.tar.gz\n",
        "!tar -xf ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8.tar.gz\n",
        "!rm ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8.tar.gz"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0gV8vr6nN-z9"
      },
      "source": [
        "## Generate TensorFlow Lite Model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z8FjeSmmxpXz"
      },
      "source": [
        "### Step 1: Export TFLite inference graph\n",
        "\n",
        "First, we invoke `export_tflite_graph_tf2.py` to generate a TFLite-friendly intermediate SavedModel. This will then be passed to the TensorFlow Lite Converter for generating the final model.\n",
        "\n",
        "Use `--help` with the above script to get the full list of supported parameters.\n",
        "These can fine-tune accuracy and speed for your model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ChfN-tzBXqko"
      },
      "outputs": [],
      "source": [
        "!python models/research/object_detection/export_tflite_graph_tf2.py \\\n",
        "    --trained_checkpoint_dir {'ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/checkpoint'} \\\n",
        "    --output_directory {'ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/tflite'} \\\n",
        "    --pipeline_config_path {'ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/pipeline.config'}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IPr06cZ3OY3H"
      },
      "source": [
        "### Step 2: Convert to TFLite\n",
        "\n",
        "Use the [TensorFlow Lite Converter](https://www.tensorflow.org/lite/convert) to\n",
        "convert the `SavedModel` to TFLite. Note that you need to use `from_saved_model`\n",
        "for TFLite conversion with the Python API.\n",
        "\n",
        "You can also leverage\n",
        "[Post-training Quantization](https://www.tensorflow.org/lite/performance/post_training_quantization)\n",
        "to\n",
        "[optimize performance](https://www.tensorflow.org/lite/performance/model_optimization)\n",
        "and obtain a smaller model. In this tutorial, we use the [dynamic range quantization](https://www.tensorflow.org/lite/performance/post_training_quant)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JMpy3Rlpq-Yq"
      },
      "outputs": [],
      "source": [
        "_TFLITE_MODEL_PATH = \"ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/model.tflite\"\n",
        "\n",
        "converter = tf.lite.TFLiteConverter.from_saved_model('ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/tflite/saved_model')\n",
        "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
        "tflite_model = converter.convert()\n",
        "\n",
        "with open(_TFLITE_MODEL_PATH, 'wb') as f:\n",
        "  f.write(tflite_model)"
      ]
    },
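    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The size benefit of quantization mentioned in Step 2 can be checked directly. The cell below is a minimal sketch and not part of the original conversion flow: it converts the same SavedModel a second time without setting `converter.optimizations` and compares the sizes of the two serialized models. The names `float_converter` and `float_tflite_model` are introduced here for illustration only, and the extra conversion takes additional time."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical baseline for comparison: the same SavedModel converted\n",
        "# without any optimizations, i.e. with float32 weights.\n",
        "float_converter = tf.lite.TFLiteConverter.from_saved_model(\n",
        "    'ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/tflite/saved_model')\n",
        "float_tflite_model = float_converter.convert()\n",
        "\n",
        "# convert() returns the serialized model as bytes, so len() gives its size.\n",
        "print('Float32 model size: %.1f MB' % (len(float_tflite_model) / 1e6))\n",
        "print('Dynamic range quantized model size: %.1f MB' % (len(tflite_model) / 1e6))"
      ]
    },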
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fyjlnmaEOtKp"
      },
      "source": [
        "### Step 3: Add Metadata\n",
        "\n",
        "The model needs to be packed with [TFLite Metadata](https://www.tensorflow.org/lite/convert/metadata) to enable easy integration into mobile apps using the [TFLite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/object_detector). This metadata helps the inference code perform the correct pre \u0026 post processing as required by the model. Use the following code to create the metadata."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-ecGLG_Ovjcr"
      },
      "outputs": [],
      "source": [
        "# Download the COCO dataset label map that was used to trained the SSD MobileNet V2 FPNLite 640x640 model\n",
        "!wget https://raw.githubusercontent.com/tensorflow/models/master/research/object_detection/data/mscoco_label_map.pbtxt -q\n",
        "\n",
        "# We need to convert the Object Detection API's labelmap into what the Task API needs:\n",
        "# a txt file with one class name on each line from index 0 to N.\n",
        "# The first '0' class indicates the background.\n",
        "# This code assumes COCO detection which has 90 classes, you can write a label\n",
        "# map file for your model if re-trained.\n",
        "_ODT_LABEL_MAP_PATH = 'mscoco_label_map.pbtxt'\n",
        "_TFLITE_LABEL_PATH = \"ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/tflite_label_map.txt\"\n",
        "\n",
        "category_index = label_map_util.create_category_index_from_labelmap(\n",
        "    _ODT_LABEL_MAP_PATH)\n",
        "f = open(_TFLITE_LABEL_PATH, 'w')\n",
        "for class_id in range(1, 91):\n",
        "  if class_id not in category_index:\n",
        "    f.write('???\\n')\n",
        "    continue\n",
        "  name = category_index[class_id]['name']\n",
        "  f.write(name+'\\n')\n",
        "f.close()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YJSyXq5Qss9X"
      },
      "source": [
        "Then we'll add the label map and other necessary metadata (e.g. normalization config) to the TFLite model.\n",
        "\n",
        "As the `SSD MobileNet V2 FPNLite 640x640` model take input image with pixel value in the range of [-1..1] ([code](https://github.com/tensorflow/models/blob/b09e75828e2c65ead9e624a5c7afed8d214247aa/research/object_detection/models/ssd_mobilenet_v2_keras_feature_extractor.py#L132)), we need to set `norm_mean = 127.5` and `norm_std = 127.5`. See this [documentation](https://www.tensorflow.org/lite/convert/metadata#normalization_and_quantization_parameters) for more details on the normalization parameters."
      ]
    },
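    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see why these values fit, the sketch below applies the normalization formula `(pixel - norm_mean) / norm_std` from the metadata documentation to a few raw pixel values. The helper `normalize_pixel` is a hypothetical name used only for illustration; it is not part of the TFLite Support Library. With a mean and standard deviation of 127.5, a raw value of 0 maps to -1.0 and a raw value of 255 maps to 1.0."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative helper (hypothetical name): maps a raw pixel value to the\n",
        "# range the model expects, using the parameters written to the metadata.\n",
        "def normalize_pixel(pixel, norm_mean=127.5, norm_std=127.5):\n",
        "  return (pixel - norm_mean) / norm_std\n",
        "\n",
        "# Check the two ends and the midpoint of the raw pixel range.\n",
        "for pixel in (0, 127.5, 255):\n",
        "  print(pixel, '->', normalize_pixel(pixel))"
      ]
    },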
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CRQpfDAWsPeK"
      },
      "outputs": [],
      "source": [
        "from tflite_support.metadata_writers import object_detector\n",
        "from tflite_support.metadata_writers import writer_utils\n",
        "\n",
        "_TFLITE_MODEL_WITH_METADATA_PATH = \"ssd_mobilenet_v2_fpnlite_640x640_coco17_tpu-8/model_with_metadata.tflite\"\n",
        "\n",
        "writer = object_detector.MetadataWriter.create_for_inference(\n",
        "    writer_utils.load_file(_TFLITE_MODEL_PATH), input_norm_mean=[127.5], \n",
        "    input_norm_std=[127.5], label_file_paths=[_TFLITE_LABEL_PATH])\n",
        "writer_utils.save_file(writer.populate(), _TFLITE_MODEL_WITH_METADATA_PATH)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YFEAjRBdPCQb"
      },
      "source": [
        "Optional: Print out the metadata added to the TFLite model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FT3-38PJsSOt"
      },
      "outputs": [],
      "source": [
        "from tflite_support import metadata\n",
        "\n",
        "displayer = metadata.MetadataDisplayer.with_model_file(_TFLITE_MODEL_WITH_METADATA_PATH)\n",
        "print(\"Metadata populated:\")\n",
        "print(displayer.get_metadata_json())\n",
        "print(\"=============================\")\n",
        "print(\"Associated file(s) populated:\")\n",
        "print(displayer.get_packed_associated_file_list())"
      ]
    },
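    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Optional: As a quick sanity check, the model can also be loaded with the TensorFlow Lite Python interpreter. The cell below is a minimal sketch: it feeds an all-zeros tensor with the shape and dtype the model reports for its input, so the detections themselves are meaningless. It only confirms that the model loads and runs, and prints the name and shape of each output tensor."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "interpreter = tf.lite.Interpreter(model_path=_TFLITE_MODEL_WITH_METADATA_PATH)\n",
        "interpreter.allocate_tensors()\n",
        "\n",
        "# Build a dummy input that matches the shape and dtype the model reports.\n",
        "input_details = interpreter.get_input_details()[0]\n",
        "dummy_input = np.zeros(tuple(input_details['shape']), dtype=input_details['dtype'])\n",
        "interpreter.set_tensor(input_details['index'], dummy_input)\n",
        "interpreter.invoke()\n",
        "\n",
        "# Print the name and shape of each output tensor.\n",
        "for output in interpreter.get_output_details():\n",
        "  print(output['name'], interpreter.get_tensor(output['index']).shape)"
      ]
    },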
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l7zVslTRnEHX"
      },
      "source": [
        "The TFLite model now can be integrated into a mobile app using the TFLite Task Library. See the [documentation](https://www.tensorflow.org/lite/inference_with_metadata/task_library/object_detector) for more details."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "Convert TF Object Detection API model to TFLite.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1R4_y-u14YTdvBzhmvC0HQwh3HkcCN2Bd",
          "timestamp": 1623114733432
        },
        {
          "file_id": "1Rey5kAzNQhJ77tsXGjhcAV0UZ6du0Sla",
          "timestamp": 1622897882140
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
